/* lco: a coroutine library for C that minimalises stack usage
   Copyright (C) 2018 Ariadne Devos

   This program is free software: you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation, either version 3 of the License, or
   (at your option) any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

#include <signal.h>
#include <stddef.h>
#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include "lco.h"

/* When non-zero, the entry points in this file validate their arguments
   and abort with a diagnostic on misuse.  (lco_init performs its checks
   unconditionally.) */
#define LCO_DO_ASSERTIONS 1
/* When non-zero, enables the stack-cleaning section of lco_free_stack. */
#define LCO_CLEAN_SENSITIVE 1

/* Write MSG to stderr exactly as given (no newline is appended) and
   terminate the process with abort().  Never returns. */
static _Noreturn void
abort_with_text(const char *msg)
{
	/* The process is about to die anyway, so a failed write is ignored. */
	(void) fprintf(stderr, "%s", msg);
	abort();
}

/* XXX check for active coroutines */

/* Allocate a stack of co->stack_size bytes, aligned to
   LCO_STACK_ALIGNMENT, and store it in co->stack_array.

   With LCO_DO_ASSERTIONS, aborts when the stack size is zero, when a
   stack is already present, or when the coroutine position is set.

   Returns 0 on success and 1 when the allocation failed; on failure
   co->stack_array is left untouched. */
_Bool
lco_alloc_stack(struct lco_coroutine *co)
{
	size_t stack_size = co->stack_size;
	if (LCO_DO_ASSERTIONS && !stack_size)
		abort_with_text("lco(alloc): stack size uninitalised");
	/* A non-null stack_array means a stack is already attached;
	   overwriting it would leak that stack. */
	if (LCO_DO_ASSERTIONS && co->stack_array)
		abort_with_text("lco(alloc): stack present");
	if (LCO_DO_ASSERTIONS && co->pos._sp)
		abort_with_text("lco(alloc): coroutine started");
	void *stack;
	/* posix_memalign returns non-zero on failure. */
	if (posix_memalign(&stack, LCO_STACK_ALIGNMENT, stack_size))
		return 1;
	co->stack_array = stack;
	return 0;
}

/* Release the stack of CO with free(); the stack must therefore come
   from a free()-compatible allocator such as the one used by
   lco_alloc_stack.

   With LCO_DO_ASSERTIONS, aborts when the coroutine position is set or
   when there is no stack.  With LCO_CLEAN_SENSITIVE, the co->stack_size
   bytes of the stack are overwritten with zeros before the memory is
   handed back.  On return co->stack_array is NULL. */
void
lco_free_stack(struct lco_coroutine *co)
{
	void *stack = co->stack_array;
	if (LCO_DO_ASSERTIONS && co->pos._sp)
		abort_with_text("lco(free): coroutine started");
	if (LCO_DO_ASSERTIONS && !stack)
		abort_with_text("lco(free): stack null");
	if (LCO_CLEAN_SENSITIVE && stack) {
		/* TODO: poison */
		/* TODO: some systems do their own poisoning */
		/* TODO: prefer memset_explicit, memset_s or explicit_bzero
		   where the platform provides them. */
		/* Stores through a volatile lvalue may not be elided, so the
		   wipe survives even though the memory is freed right after. */
		volatile unsigned char *bytes = stack;
		size_t size = co->stack_size;
		for (size_t i = 0; i < size; i++)
			bytes[i] = 0;
	}
	co->stack_array = NULL;
	free(stack);
}

void
lco_auto_sigaltstack(size_t size, char *array)
{
	stack_t newstack = { .ss_sp = array, .ss_flags = 0, .ss_size = size };
	if (sigaltstack(&newstack, NULL) == -1)
		abort_with_text("lco(sigaltstack)");
}


/* For switching away, a function is called (written in assembly).
   It must save registers, but it doesn't have to save all of them.

   (Source: System V Application Binary Interface, AMD64 Architecture Processor
   Supplement, Draft Version 0.99.6, Edited by Michael Matz, Jan Hubička, Andreas
   Jaeger, Mark Mitchell, July 2, 2012)

   On x86-64, the following registers are 'preserved across function calls':
   %rbx, %rsp, %rbp, %r12-r15, mxcsr (partially), x87 CW (Figure 3.4: Register
   Usage).

   On x86-64, sp is the top of the stack.
 */
/* Stack switch routine, defined outside this file.  The callers in this
   file pass the saved stack pointer of the context to continue as
   next_sp, and the address of the field that holds the stack pointer of
   the context being left as current_sp (presumably the routine stores
   the outgoing stack pointer there -- confirm against the assembly). */
extern void
_lco_switch(void *next_sp, void **current_sp);

/* Entry routine for a fresh coroutine, defined outside this file.  It is
   never called from C here: lco_init only stores its address on the new
   stack.  Declared without a prototype because it does not follow the
   normal C calling convention (its argument is expected in %rbx
   according to the comments in lco_init). */
extern void
_lco_call();

/* Integer views for storing values into raw stack memory.
   may_alias: accesses through these types may alias objects of any type.
   aligned(8): the accessed address is declared to be 8-byte aligned. */
typedef uint64_t interpret_u64 __attribute__((__may_alias__, __aligned__(8)));
typedef uint32_t interpret_u32 __attribute__((__may_alias__, __aligned__(8)));
typedef uint16_t interpret_u16 __attribute__((__may_alias__, __aligned__(8)));

/* Build the initial stack frame of CO so that the first switch to it
   starts executing in _lco_call, and record the resulting stack pointer
   in co->pos._sp.

   Aborts (unconditionally, not only with LCO_DO_ASSERTIONS) when CO has
   no stack, when the stack size is not a multiple of
   LCO_STACK_ALIGNMENT, or when the stack cannot hold the initial frame.

   x86-64 layout, from the top of the stack downwards:
     top -  8   zero
     top - 16   address of _lco_call
     top - 80   64-byte register save area; co->pos._sp points here:
                %rbx, %rbp, %r12, %r13, %r14, %r15, x87 control word,
                mxcsr (8 bytes each) */
void
lco_init(struct lco_coroutine *co)
{
	char *stack_array = co->stack_array;
	size_t stack_size = co->stack_size;
	if (!stack_array)
		abort_with_text("lco(init): stack null");
	if (stack_size % LCO_STACK_ALIGNMENT)
		abort_with_text("lco(init): stack not aligned");
#if defined(__x86_64)
	if (stack_size < 64 + 2*8)
		abort_with_text("lco(init): stack too small");
	char *sp = stack_array + stack_size;
	/* 'Caller' */
	*(interpret_u64 *) (sp - 8) = 0;
	/* Pointers are converted explicitly: storing a pointer into an
	   integer lvalue without a cast is a constraint violation. */
	*(interpret_u64 *) (sp - 16) = (uint64_t) (uintptr_t) &_lco_call;
	sp -= 16 + 64;
	/* movq (%rsp), %rbx */
	/* %rbx is the lco_coroutine argument to _lco_call */
	*(interpret_u64 *) (sp + 0) = (uint64_t) (uintptr_t) co;
	/* movq 8(%rsp), %rbp */
	*(interpret_u64 *) (sp + 8) = 0;
	/* movq 16(%rsp), %r12 */
	*(interpret_u64 *) (sp + 16) = 0;
	/* movq 24(%rsp), %r13 */
	*(interpret_u64 *) (sp + 24) = 0;
	/* movq 32(%rsp), %r14 */
	*(interpret_u64 *) (sp + 32) = 0;
	/* movq 40(%rsp), %r15 */
	*(interpret_u64 *) (sp + 40) = 0;
	/* fldcw 48(%rsp) */
	/* Set x87 control word */
	/* RC 0, PC 11, PM 1, UM 1, OM 1, ZM 1, DM 1, IM 1 is standard according
	   to (Source: ...). */
	/* Clear the whole 8-byte slot first, then fill in the low 16 bits. */
	*(interpret_u64 *)(sp + 48) = 0;
	*(interpret_u16 *)(sp + 48) = 0b0000001100111111;
	/* ldmxcsr 56(%rsp) */
	/* Set status bits of mxcsr */
	/* FZ 0, RC 0, PM 1, UM 1, OM 1, ZM 1, DM 1, IM 1, DAZ 0 */
	/* Clear the whole 8-byte slot (bytes 56..63), then fill in the low
	   32 bits; this leaves no uninitialised padding in the frame. */
	*(interpret_u64 *)(sp + 56) = 0;
	*(interpret_u32 *)(sp + 56) = 0b0001111110000000;

	co->pos._sp = sp;
	/* Done! */
#else
#	error Do not know the stack layout for your architecture!
#endif
}

/* Switch from the coroutine CURRENT to the coroutine CO: the saved
   position of CO is handed to _lco_switch together with the address of
   CURRENT's position field.

   FLAGS must be zero.  With LCO_DO_ASSERTIONS, aborts on non-zero
   flags, when either coroutine has no position, or when CO and CURRENT
   are the same coroutine. */
void
lco_continue(struct lco_coroutine *co, struct lco_coroutine *current, unsigned long flags)
{
	if (LCO_DO_ASSERTIONS) {
		if (flags)
			abort_with_text("lco(continue): unsupported flags");
		if (!co->pos._sp)
			abort_with_text("lco(continue): inactive");
		if (!current->pos._sp)
			abort_with_text("lco(continue): inactive");
		if (current == co)
			abort_with_text("lco(continue): equal continuations");
	}

	void **save_slot = &current->pos._sp;
	void *target_sp = co->pos._sp;
	_lco_switch(target_sp, save_slot);
}

/* Switch from the plain thread context to the coroutine CO: the saved
   position of CO is handed to _lco_switch together with the address of
   THREAD's stack pointer field.

   FLAGS must be zero.  With LCO_DO_ASSERTIONS, aborts on non-zero flags
   or when CO has no position. */
void
lco_resume(struct lco_coroutine *co, struct lco_position *thread, unsigned long flags)
{
	if (LCO_DO_ASSERTIONS) {
		if (flags)
			abort_with_text("lco(resume): unsupported flags");
		if (!co->pos._sp)
			abort_with_text("lco(resume): inactive");
	}

	void **save_slot = &thread->_sp;
	void *target_sp = co->pos._sp;
	_lco_switch(target_sp, save_slot);
}

/* Switch from the coroutine CURRENT back to the thread context: the
   stack pointer recorded in THREAD is handed to _lco_switch together
   with the address of CURRENT's position field.

   FLAGS must be zero.  With LCO_DO_ASSERTIONS, aborts on non-zero flags
   or when CURRENT has no position.  THREAD is not validated here. */
void
lco_pause(struct lco_coroutine *current, struct lco_position *thread, unsigned long flags)
{
	if (LCO_DO_ASSERTIONS) {
		if (flags)
			abort_with_text("lco(pause): unsupported flags");
		if (!current->pos._sp)
			abort_with_text("lco(pause): inactive");
	}

	void **save_slot = &current->pos._sp;
	void *target_sp = thread->_sp;
	_lco_switch(target_sp, save_slot);
}

